{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Control Ops Tutorial\n",
    "\n",
    "In this tutorial we show how to use control flow operators in Caffe2 and give some details on their underlying implementations."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Conditional Execution"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's start with the conditional operator. We will demonstrate how to use it with two Caffe2 APIs for building nets: NetBuilder and Brew."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the first example, we define several blobs and then use the 'If' operator to set the value of one of them depending on the values of the others."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from caffe2.python import workspace\n",
    "from caffe2.python.core import Plan, to_execution_step, Net\n",
    "from caffe2.python.net_builder import ops, NetBuilder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "with NetBuilder() as nb:\n",
    "    ops.Const(0.0, blob_out=\"zero\")\n",
    "    ops.Const(1.0, blob_out=\"one\")\n",
    "    ops.Const(0.5, blob_out=\"x\")\n",
    "    ops.Const(0.0, blob_out=\"y\")\n",
    "    with ops.IfNet(ops.GT([\"x\", \"zero\"])):\n",
    "        ops.Copy(\"one\", \"y\")\n",
    "    with ops.Else():\n",
    "        ops.Copy(\"zero\", \"y\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note the usage of NetBuilder's ops.IfNet and ops.Else calls. ops.IfNet accepts a blob reference or a blob name as input and expects that blob to hold a scalar value convertible to bool. The optional ops.Else is at the same level as ops.IfNet and immediately follows the corresponding ops.IfNet. Let's execute the resulting net (execution step) and check the values of the blobs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('x = ', array(0.5, dtype=float32))\n",
      "('y = ', array(1.0, dtype=float32))\n"
     ]
    }
   ],
   "source": [
    "plan = Plan('if_net_test')\n",
    "plan.AddStep(to_execution_step(nb))\n",
    "ws = workspace.C.Workspace()\n",
    "ws.run(plan)\n",
    "print('x = ', ws.blobs[\"x\"].fetch())\n",
    "print('y = ', ws.blobs[\"y\"].fetch())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before going further, it's important to understand the semantics of execution blocks ('then' and 'else' branches in the example above), i.e. handling of reads and writes into global (defined outside of the block) and local (defined inside the block) blobs."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "NetBuilder uses the following set of rules:\n",
    " - In NetBuilder's syntax, a blob's declaration and definition occur at the same time: when we define an operator that writes its output into a blob with a given name.\n",
    " - NetBuilder keeps track of all operator outputs seen before the current execution point in the same block and up the stack in parent blocks.\n",
    " - If an operator writes into a previously unseen blob, it creates a **local** blob that is visible only within the current block and its subsequent child blocks. Local blobs created in a given block are effectively deleted when we exit the block. Any write into a previously defined blob (in the same block or in a parent block) updates the originally created blob and does not redefine it.\n",
    " - An operator's input blobs have to be defined earlier in the same block or in the parent blocks up the stack."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a result, in order to see the values computed by a block after the block has finished, the corresponding blobs have to be defined outside of the block. Note that this is one way to solve the problem of uninitialized blobs (e.g. a blob created by the 'then' branch but not by the 'else' branch): these rules effectively force visible blobs to always be correctly initialized."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate the block semantics with a more sophisticated example, let's consider the following net:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "with NetBuilder() as nb:\n",
    "    ops.Const(0.0, blob_out=\"zero\")\n",
    "    ops.Const(1.0, blob_out=\"one\")\n",
    "    ops.Const(2.0, blob_out=\"two\")\n",
    "    ops.Const(1.5, blob_out=\"x\")\n",
    "    ops.Const(0.0, blob_out=\"y\")\n",
    "    with ops.IfNet(ops.GT([\"x\", \"zero\"])):\n",
    "        ops.Copy(\"x\", \"local_blob\")\n",
    "        with ops.IfNet(ops.LE([\"local_blob\", \"one\"])):\n",
    "            ops.Copy(\"one\", \"y\")\n",
    "        with ops.Else():\n",
    "            ops.Copy(\"two\", \"y\")\n",
    "    with ops.Else():\n",
    "        ops.Copy(\"zero\", \"y\")\n",
    "        # ops.Copy(\"local_blob\", \"z\") - fails with exception during net construction, local_blob is undefined"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('x = ', array(1.5, dtype=float32))\n",
      "('y = ', array(2.0, dtype=float32))\n"
     ]
    }
   ],
   "source": [
    "plan = Plan('if_net_test_2')\n",
    "plan.AddStep(to_execution_step(nb))\n",
    "ws = workspace.C.Workspace()\n",
    "ws.run(plan)\n",
    "print('x = ', ws.blobs[\"x\"].fetch())\n",
    "print('y = ', ws.blobs[\"y\"].fetch())\n",
    "assert \"local_blob\" not in ws.blobs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Brew is another Caffe2 interface used to construct nets. Unlike NetBuilder, Brew does not track the hierarchy of blocks and, as a result, we need to specify which blobs are considered local and which global when passing 'then' and 'else' models to an API call:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from caffe2.python import brew\n",
    "from caffe2.python.workspace import FeedBlob, RunNetOnce, FetchBlob\n",
    "from caffe2.python.model_helper import ModelHelper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model = ModelHelper(name=\"test_if_model\")\n",
    "\n",
    "model.param_init_net.ConstantFill([], [\"zero\"], shape=[1], value=0.0)\n",
    "model.param_init_net.ConstantFill([], [\"one\"], shape=[1], value=1.0)\n",
    "model.param_init_net.ConstantFill([], [\"x\"], shape=[1], value=0.5)\n",
    "model.param_init_net.ConstantFill([], [\"y\"], shape=[1], value=0.0)\n",
    "model.param_init_net.GT([\"x\", \"zero\"], \"cond\")\n",
    "\n",
    "then_model = ModelHelper(name=\"then_test_model\")\n",
    "then_model.net.Copy(\"one\", \"y\")\n",
    "\n",
    "else_model = ModelHelper(name=\"else_test_model\")\n",
    "else_model.net.Copy(\"zero\", \"y\")\n",
    "\n",
    "brew.cond(\n",
    "    model=model,\n",
    "    cond_blob=\"cond\", # blob with condition value\n",
    "    external_blobs=[\"x\", \"y\", \"zero\", \"one\"], # writes into these blobs update existing blobs\n",
    "    then_model=then_model,\n",
    "    else_model=else_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To run the resulting init and main nets:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('x = ', array([ 0.5], dtype=float32))\n",
      "('y = ', array([ 1.], dtype=float32))\n"
     ]
    }
   ],
   "source": [
    "RunNetOnce(model.param_init_net)\n",
    "RunNetOnce(model.net)\n",
    "print(\"x = \", FetchBlob(\"x\"))\n",
    "print(\"y = \", FetchBlob(\"y\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loops"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Another important control flow operator is 'While', which allows repeated execution of a fragment of a net. Let's consider NetBuilder's version of While first."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "with NetBuilder() as nb:\n",
    "    ops.Const(0, blob_out=\"i\")\n",
    "    ops.Const(0, blob_out=\"y\")\n",
    "    with ops.WhileNet():\n",
    "        with ops.Condition():\n",
    "            ops.Add([\"i\", ops.Const(1)], [\"i\"])\n",
    "            ops.LE([\"i\", ops.Const(7)])\n",
    "        ops.Add([\"i\", \"y\"], [\"y\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This example illustrates the usage of a 'while' loop with NetBuilder. As with the 'If' operator, standard block semantic rules apply. Note the usage of the ops.Condition clause, which should immediately follow ops.WhileNet and contains code that is executed before each iteration. The last operator in the condition clause is expected to have a single boolean output that determines whether another iteration is executed."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the example above we increment the counter (\"i\") before each iteration and accumulate its values in the \"y\" blob. The loop's body is executed 7 times; the resulting blob values are:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('i = ', array(8L))\n",
      "('y = ', array(28L))\n"
     ]
    }
   ],
   "source": [
    "plan = Plan('while_net_test')\n",
    "plan.AddStep(to_execution_step(nb))\n",
    "ws = workspace.C.Workspace()\n",
    "ws.run(plan)\n",
    "print(\"i = \", ws.blobs[\"i\"].fetch())\n",
    "print(\"y = \", ws.blobs[\"y\"].fetch())"
   ]
  },
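  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of these values (a worked sketch, assuming the condition clause runs before every iteration as described above): the body adds \"i\" to \"y\" for i = 1, ..., 7, so\n",
    "\n",
    "$$y = \\sum_{i=1}^{7} i = \\frac{7 \\cdot 8}{2} = 28$$\n",
    "\n",
    "The condition clause then runs once more, incrementing \"i\" to 8 before the comparison with 7 fails, which is why the final value of \"i\" is 8 rather than 7."
   ]
  },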
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The corresponding Brew example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model = ModelHelper(name=\"test_while_model\")\n",
    "\n",
    "model.param_init_net.ConstantFill([], [\"i\"], shape=[1], value=0)\n",
    "model.param_init_net.ConstantFill([], [\"one\"], shape=[1], value=1)\n",
    "model.param_init_net.ConstantFill([], [\"seven\"], shape=[1], value=7)\n",
    "model.param_init_net.ConstantFill([], [\"y\"], shape=[1], value=0)\n",
    "\n",
    "loop_model = ModelHelper(name=\"loop_test_model\")\n",
    "loop_model.net.Add([\"i\", \"y\"], [\"y\"])\n",
    "\n",
    "cond_model = ModelHelper(name=\"cond_test_model\")\n",
    "cond_model.net.Add([\"i\", \"one\"], \"i\")\n",
    "cond_model.net.LE([\"i\", \"seven\"], \"cond\")\n",
    "\n",
    "brew.loop(\n",
    "    model=model,\n",
    "    cond_blob=\"cond\", # explicitly specifying condition blob\n",
    "    external_blobs=[\"cond\", \"i\", \"one\", \"seven\", \"y\"],\n",
    "    loop_model=loop_model,\n",
    "    cond_model=cond_model # condition model is optional\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Corresponding blob values:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('i = ', array([8]))\n",
      "('y = ', array([28]))\n"
     ]
    }
   ],
   "source": [
    "RunNetOnce(model.param_init_net)\n",
    "RunNetOnce(model.net)\n",
    "print(\"i = \", FetchBlob(\"i\"))\n",
    "print(\"y = \", FetchBlob(\"y\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Backpropagation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both the 'If' and 'While' operators support backpropagation. To illustrate how backpropagation with control ops works, let's consider the following example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "# _use_control_ops=True forces NetBuilder to output single net as a result\n",
    "# x is external for NetBuilder, so letting nb know about it through initial_scope param\n",
    "FeedBlob(\"x\", np.array(0.5, dtype='float32'))\n",
    "with NetBuilder(_use_control_ops=True, initial_scope=[\"x\"]) as nb:\n",
    "    ops.Const(0.0, blob_out=\"zero\")\n",
    "    ops.Const(1.0, blob_out=\"one\")\n",
    "    ops.Const(4.0, blob_out=\"y\")\n",
    "    ops.Const(0.0, blob_out=\"z\")\n",
    "    with ops.IfNet(ops.GT([\"x\", \"zero\"])):\n",
    "        ops.Pow(\"y\", \"z\", exponent=2.0)\n",
    "    with ops.Else():\n",
    "        ops.Pow(\"y\", \"z\", exponent=3.0)\n",
    "\n",
    "assert len(nb.get()) == 1, \"Expected a single net produced\"\n",
    "net = nb.get()[0]\n",
    "grad_map = net.AddGradientOperators([\"z\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The output blob \"z\" as a function of \"y\" depends on the value of blob \"x\": if \"x\" is greater than zero, then \"z = y^2\", otherwise \"z = y^3\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('x = ', array(0.5, dtype=float32))\n",
      "('y = ', array(4.0, dtype=float32))\n",
      "('z = ', array(16.0, dtype=float32))\n",
      "('y_grad = ', array(8.0, dtype=float32))\n"
     ]
    }
   ],
   "source": [
    "RunNetOnce(net)\n",
    "print(\"x = \", FetchBlob(\"x\"))\n",
    "print(\"y = \", FetchBlob(\"y\"))\n",
    "print(\"z = \", FetchBlob(\"z\"))\n",
    "print(\"y_grad = \", FetchBlob(\"y_grad\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's change the value of blob \"x\" and rerun the net:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('x = ', array(-0.5, dtype=float32))\n",
      "('y = ', array(4.0, dtype=float32))\n",
      "('z = ', array(64.0, dtype=float32))\n",
      "('y_grad = ', array(48.0, dtype=float32))\n"
     ]
    }
   ],
   "source": [
    "FeedBlob(\"x\", np.array(-0.5, dtype='float32'))\n",
    "RunNetOnce(net)\n",
    "print(\"x = \", FetchBlob(\"x\"))\n",
    "print(\"y = \", FetchBlob(\"y\"))\n",
    "print(\"z = \", FetchBlob(\"z\"))\n",
    "print(\"y_grad = \", FetchBlob(\"y_grad\"))"
   ]
  },
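  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both results agree with differentiating each branch by hand at y = 4. This is only a worked check; the derivative notation below is ours and not part of the Caffe2 API:\n",
    "\n",
    "$$x > 0: \\quad z = y^2, \\quad \\frac{\\partial z}{\\partial y} = 2y = 8$$\n",
    "\n",
    "$$x \\le 0: \\quad z = y^3, \\quad \\frac{\\partial z}{\\partial y} = 3y^2 = 48$$\n",
    "\n",
    "In each run, the gradient comes from the branch that was actually taken in the forward pass."
   ]
  },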
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An example illustrating backpropagation through a loop:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "with NetBuilder(_use_control_ops=True) as nb:\n",
    "    ops.Copy(ops.Const(0), \"i\")\n",
    "    ops.Copy(ops.Const(1), \"one\")\n",
    "    ops.Copy(ops.Const(2), \"two\")\n",
    "    ops.Copy(ops.Const(2.0), \"x\")\n",
    "    ops.Copy(ops.Const(3.0), \"y\")\n",
    "    ops.Copy(ops.Const(2.0), \"z\")\n",
    "    # computes x^4, y^2, z^3\n",
    "    with ops.WhileNet():\n",
    "        with ops.Condition():\n",
    "            ops.Add([\"i\", \"one\"], \"i\")\n",
    "            ops.LE([\"i\", \"two\"])\n",
    "        ops.Pow(\"x\", \"x\", exponent=2.0)\n",
    "        with ops.IfNet(ops.LT([\"i\", \"two\"])):\n",
    "            ops.Pow(\"y\", \"y\", exponent=2.0)\n",
    "        with ops.Else():\n",
    "            ops.Pow(\"z\", \"z\", exponent=3.0)\n",
    "\n",
    "    ops.Add([\"x\", \"y\"], \"x_plus_y\")\n",
    "    ops.Add([\"x_plus_y\", \"z\"], \"s\")\n",
    "\n",
    "assert len(nb.get()) == 1, \"Expected a single net produced\"\n",
    "net = nb.get()[0]\n",
    "\n",
    "grad_map = net.AddGradientOperators([\"s\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('x = ', array(16.0, dtype=float32))\n",
      "('x_grad = ', array(32.0, dtype=float32))\n",
      "('y = ', array(9.0, dtype=float32))\n",
      "('y_grad = ', array(6.0, dtype=float32))\n",
      "('z = ', array(8.0, dtype=float32))\n",
      "('z_grad = ', array(12.0, dtype=float32))\n"
     ]
    }
   ],
   "source": [
    "workspace.RunNetOnce(net)\n",
    "print(\"x = \", FetchBlob(\"x\"))\n",
    "print(\"x_grad = \", FetchBlob(\"x_grad\")) # 4x^3\n",
    "print(\"y = \", FetchBlob(\"y\"))\n",
    "print(\"y_grad = \", FetchBlob(\"y_grad\")) # 2y\n",
    "print(\"z = \", FetchBlob(\"z\"))\n",
    "print(\"z_grad = \", FetchBlob(\"z_grad\")) # 3z^2"
   ]
  },
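  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The printed values can be checked by hand. This is a worked sketch; it assumes the reported gradients are taken with respect to the initial values x = 2, y = 3, z = 2, which is consistent with the output above. The loop body runs twice (i = 1 and i = 2): \"x\" is squared in both iterations, \"y\" is squared only when i < 2, and \"z\" is cubed only in the last iteration, so\n",
    "\n",
    "$$s = x^4 + y^2 + z^3 = 16 + 9 + 8$$\n",
    "\n",
    "$$\\frac{\\partial s}{\\partial x} = 4x^3 = 32, \\quad \\frac{\\partial s}{\\partial y} = 2y = 6, \\quad \\frac{\\partial s}{\\partial z} = 3z^2 = 12$$"
   ]
  },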
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Implementation Notes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "At the low level, Caffe2 uses the following set of operators to implement forward and backward branching and loops:\n",
    "- If - accepts *then_net* and *else_net* nets as arguments and executes one of them, depending on the value of the input condition blob; the nets are executed **in the same** workspace;\n",
    "- While - repeats execution of the *loop_net* net passed as an argument; the net is executed in the same workspace;\n",
    "- Do - a special operator that creates a separate inner workspace, sets up blob mappings between the outer and inner workspaces, and runs a net in the inner workspace;\n",
    "- CreateScope/HasScope - special operators that create and keep track of the workspaces used by the Do operator.\n",
    "\n",
    "Higher level libraries that implement branching and looping (e.g. NetBuilder, Brew) use these operators to build control flow, e.g. for 'If':\n",
    " - do the necessary sanity checks (e.g. determine which blobs are initialized and check that a subnet does not read undefined blobs)\n",
    " - wrap the 'then' and 'else' branches into Do\n",
    " - set up correct blob mappings by specifying which local names are mapped to outer blobs\n",
    " - prepare the scope structure used by the Do operator\n",
    "\n",
    "While the 'If' and 'While' Caffe2 ops can be used directly without creating local block workspaces, we encourage users to use the higher level Caffe2 interfaces, which provide the necessary correctness guarantees.\n",
    "\n",
    "Backpropagation for 'While' is in general expensive memory-wise: we have to save the local workspace for every iteration of a block, including global blobs visible to the block. It is recommended that users use the RecurrentNetwork operator instead in production environments."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
